resources:
  # A single A100 GPU; presumably sized for the default MODEL_SIZE in envs.
  accelerators: A100:1
  # Host memory floor ("32+" = at least 32 GB per the SkyPilot spec).
  memory: "32+"
  # Boot disk size (GB) and performance tier ("best" = best tier available).
  disk_size: 1024
  disk_tier: best

envs:
  # Selects the Llama-2-<size>b-chat checkpoint that setup clones and run serves.
  # Quoted so YAML yields the string "7" rather than an int.
  # NOTE(review): larger sizes likely need more GPUs than `accelerators`
  # requests — confirm before changing.
  MODEL_SIZE: "7"
  # Hugging Face token; it must be able to read the meta-llama/Llama-2-* repo
  # cloned in setup. Prefer passing it at launch time
  # (e.g. `sky launch --env HF_TOKEN=...`) instead of committing it here.
  HF_TOKEN: "<your-huggingface-token>"  # TODO: replace with your Hugging Face token

setup: |
  set -ex

  # Each clone is skipped only when its checkout already exists. Re-running
  # setup on an existing cluster therefore succeeds under `set -e`, while a
  # genuine clone failure still aborts setup.

  # Meta's reference Llama code, installed in editable mode.
  [ -d llama ] || git clone https://github.com/facebookresearch/llama.git
  (cd llama && pip install -e .)

  # SkyPilot's chat frontend and its pinned CUDA 11.3 PyTorch build.
  [ -d sky-llama ] || git clone https://github.com/skypilot-org/sky-llama.git
  (
    cd sky-llama
    pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
    pip install -r requirements.txt
    pip install -e .
  )

  # Download the model weights from the Hugging Face Hub rather than with the
  # official download script, which has problems.
  git config --global credential.helper cache
  sudo apt-get update
  sudo apt-get install -y git-lfs
  # Register the LFS filters so the clone below fetches the real weight files.
  git lfs install
  pip install transformers
  # Read the token from the environment inside Python instead of splicing it
  # into the command line: `set -x` echoes every expanded command, which would
  # write the token into the setup logs.
  python -c "import os, huggingface_hub; huggingface_hub.login(os.environ['HF_TOKEN'], add_to_git_credential=True)"
  [ -d Llama-2-${MODEL_SIZE}b-chat ] || git clone https://huggingface.co/meta-llama/Llama-2-${MODEL_SIZE}b-chat

  # ttyd binary used by `run` to wrap the chat CLI. `-O` overwrites any
  # leftover download instead of creating ttyd.x86_64.1 on a re-run.
  wget -O ttyd.x86_64 https://github.com/tsl0922/ttyd/releases/download/1.7.2/ttyd.x86_64
  sudo mv ttyd.x86_64 /usr/local/bin/ttyd
  sudo chmod +x /usr/local/bin/ttyd

run: |
  # Serve the Llama-2 chat demo through ttyd, with one torchrun worker per GPU
  # on this node.
  # NOTE(review): assumes setup ran in ~/sky_workdir, so the weights it cloned
  # live there — confirm.
  cd sky-llama
  CKPT_DIR=~/sky_workdir/Llama-2-${MODEL_SIZE}b-chat
  ttyd /bin/bash -c "torchrun --nproc_per_node ${SKYPILOT_NUM_GPUS_PER_NODE} chat.py --ckpt_dir ${CKPT_DIR} --tokenizer_path ${CKPT_DIR}/tokenizer.model"
